library(here)
library(tidyverse)
library(scales)
library(redist)
library(sf)
library(patchwork)

source(here("R/00_custom_functions.R"))

# Cache file holding the stacked per-district tallies built further down.
path <- here("data/sims_race_sum.rds")

# TRUE for draw labels that match any of the reference-plan patterns below;
# rows flagged this way are used as the baseline rather than as simulations.
is_ref <- function(x) {
    ref_patterns <- c("cd", "<init", "sld_up")
    hits <- lapply(ref_patterns, function(pattern) str_detect(x, pattern))
    Reduce(`|`, hits)
}

# Load plans and simulations -----------------------------------------------------
# Reuse the cached tallies when the cache file exists; otherwise rebuild them
# from the shapefiles and the simulated plans and write the cache to `path`.
if (file.exists(path)) {
    pcts <- read_rds(path)
} else {
    # Shapefiles --------------------------------------------------------------
    # Each one gets "redist_map" prepended to its class vector before it is
    # passed to group_frac() in the tally functions below.
    pa_vtd_orig <- read_rds(here("data/PA/pa_vtd_orig.rds"))
    pa_shp <- left_join(read_rds(here("data/PA/pa_shp.rds")), pa_vtd_orig, by = "precinct")
    class(pa_shp) <- c("redist_map", class(pa_shp))

    sc_shp <- read_rds(here("data/SC/sc_shp.rds"))
    class(sc_shp) <- c("redist_map", class(sc_shp))

    nc_shp <- read_rds(here("data/NC/nc_shp.rds"))
    class(nc_shp) <- c("redist_map", class(nc_shp))

    ms_shp <- read_rds(here("data/MS/ms.Rds"))
    class(ms_shp) <- c("redist_map", class(ms_shp))

    # Tally helpers -----------------------------------------------------------
    # Turn a plans object into a plain tibble and drop the `plans`, `prec_pop`
    # and `wgt` attributes (presumably to keep the cached file small).
    clean_plans <- function(pl) {
        out <- as_tibble(pl)
        for (nm in c("plans", "prec_pop", "wgt")) {
            attr(out, nm) <- NULL
        }
        out
    }

    # Each tally_* adds `grp` (the group's share of district population from
    # group_frac()), renumbers the districts with number_by(grp), and strips
    # the result with clean_plans(). PA uses non-white population; SC, NC and
    # MS use Black population.
    tally_pa <- function(pl) {
        with_share <- mutate(pl, grp = group_frac(pa_shp, pop - pop_white, pop))
        clean_plans(number_by(with_share, grp))
    }
    tally_sc <- function(pl) {
        with_share <- mutate(pl, grp = group_frac(sc_shp, pop_black, pop))
        clean_plans(number_by(with_share, grp))
    }
    tally_nc <- function(pl) {
        with_share <- mutate(pl, grp = group_frac(nc_shp, pop_black, pop))
        clean_plans(number_by(with_share, grp))
    }
    tally_ms <- function(pl) {
        with_share <- mutate(pl, grp = group_frac(ms_shp, pop_black, pop))
        clean_plans(number_by(with_share, grp))
    }

    # Simulated plans ---------------------------------------------------------
    # Read and tally every file in a named vector of project-relative paths;
    # the vector's names (das04, das12, das19, orig) become the list names.
    read_sims <- function(files, tally) {
        map(files, function(f) tally(read_rds(here(f))))
    }

    # Outer names become the `state` column, inner names the `source` column.
    sims <- list(
        pa_cd = read_sims(c(
            das04 = "data/PA/sim_das04_cty_001_10k.rds",
            das12 = "data/PA/sim_das12_cty_001_10k.rds",
            das19 = "data/PA/sim_das19_cty_001_10k.rds",
            orig = "data/PA/sim_orig_cty_001_10k.rds"
        ), tally_pa),
        sc_cd = read_sims(c(
            das04 = "data/SC/sim/CD/plans_da04.rds",
            das12 = "data/SC/sim/CD/plans_da12.rds",
            das19 = "data/SC/sim/CD/plans_da19.rds",
            orig = "data/SC/sim/CD/plans_orig.rds"
        ), tally_sc),
        nc_cd = read_sims(c(
            das04 = "data/NC/sim/CD/plans_da04.rds",
            das12 = "data/NC/sim/CD/plans_da12.rds",
            das19 = "data/NC/sim/CD/plans_da19.rds",
            orig = "data/NC/sim/CD/plans_orig.rds"
        ), tally_nc),
        sc_hd = read_sims(c(
            das04 = "data/SC/sim/HD_ms-parallel/plans_da04.rds",
            das12 = "data/SC/sim/HD_ms-parallel/plans_da12.rds",
            das19 = "data/SC/sim/HD_ms-parallel/plans_da19.rds",
            orig = "data/SC/sim/HD_ms-parallel/plans_orig.rds"
        ), tally_sc),
        ms_sd = read_sims(c(
            das04 = "data/MS/sim/redist_plans_ms_v4.rds",
            das12 = "data/MS/sim/redist_plans_ms_v12.rds",
            das19 = "data/MS/sim/redist_plans_ms_v19.rds",
            orig = "data/MS/sim/redist_plans_ms_cen.rds"
        ), tally_ms)
    )

    # The shapefiles are only needed by the tally functions above.
    rm(pa_shp, nc_shp, sc_shp, ms_shp)

    # Stack and cache ---------------------------------------------------------
    pcts <- sims %>%
        map(function(by_source) bind_rows(by_source, .id = "source")) %>%
        bind_rows(.id = "state")
    rm(sims, read_sims)

    write_rds(pcts, path, compress = "xz")
}


# Keep rows whose district number is whole, attach to every row the `grp` value
# of the first reference draw for the same state/source/district, and then
# drop the reference rows so only the remaining draws are left.
pcts_grp0 <- pcts %>%
    filter(as.integer(district) == district) %>%
    mutate(.is_ref = is_ref(draw)) %>%
    group_by(state, source, district) %>%
    mutate(ref_grp = grp[.is_ref][1]) %>%
    ungroup() %>%
    filter(!.is_ref) %>%
    select(-.is_ref)


# Bucket the many-district maps so they fit on one axis: each district number
# is replaced by the first district of its bucket. SC House uses buckets of 5
# (1, 6, 11, ...), MS Senate buckets of 4 (1, 5, 9, ...).
pcts_sc_hd <- mutate(
    filter(pcts_grp0, state == "sc_hd"),
    district_new = 5 * as.integer((district - 1) / 5) + 1
)
pcts_ms_sd <- mutate(
    filter(pcts_grp0, state == "ms_sd"),
    district_new = 4 * as.integer((district - 1) / 4) + 1
)

# Restack: the congressional rows have no `district_new`, so coalesce() leaves
# their district numbers as they were.
pcts_grp <- pcts_grp0 %>%
    filter(!state %in% c("sc_hd", "ms_sd")) %>%
    bind_rows(pcts_sc_hd, pcts_ms_sd) %>%
    mutate(district = coalesce(district_new, district))

# Summary statistics ---
# Five-number summary of the deviation from the reference plan (grp - ref_grp)
# for every state / source / district (or district bucket).
pcts_box <- pcts_grp %>%
    mutate(dev = grp - ref_grp) %>%
    group_by(state, source, district) %>%
    summarize(
        med = median(dev),
        q1 = quantile(dev, 0.25),
        q3 = quantile(dev, 0.75),
        low = min(dev),
        high = max(dev),
        .groups = "drop"
    )

# The raw stacked table is not used below this point.
rm(pcts)

# Patchwork layout of the five race box-plot panels
#' @param data table with all summary stats (defaults to `pcts_box`)
#' @param sources which sources to show; passed to `plot_boxes()` as `show`
#' @return the combined patchwork, with the guides collected at the bottom
patch_boxes_race <- function(data = pcts_box, sources = c("Census 2010", "DAS-12.2", "DAS-4.5")) {
    # All panels share `data` and `sources`; only the geography, the title and
    # (for the two legislative maps) the `grouped` setting differ.
    draw_panel <- function(geo, title, ...) {
        plot_boxes(data, geo = geo, title = title, show = sources, ...)
    }
    # Slanted x-axis labels for the two panels with bucketed districts.
    slanted_x <- theme(axis.text.x = element_text(angle = 30, hjust = 1))

    nc_cd <- draw_panel("nc_cd", "North Carolina Congressional")
    sc_cd <- draw_panel("sc_cd", "South Carolina Congressional") +
        labs(y = NULL)
    sc_hd <- draw_panel("sc_hd", "South Carolina State House", grouped = c(by = 4, max = 124)) +
        slanted_x
    ms_sd <- draw_panel("ms_sd", "Mississippi State Senate", grouped = c(by = 3, max = 52)) +
        slanted_x
    pa_cd <- draw_panel("pa_cd", "Pennsylvania Congressional") +
        labs(x = "Districts, ordered by Black share (minority share in PA)")

    # Areas A-E are taken by the panels in the order they are added below.
    panel_design <-
        "
         AAABB
         CCCCC
         DDDDD
         EEEEE
        "
    nc_cd + sc_cd + sc_hd + ms_sd + pa_cd +
        plot_layout(design = panel_design, guides = "collect") &
        theme(legend.position = "bottom")
}


# Generate plots (plot_boxes() is presumably defined in R/00_custom_functions.R,
# which is sourced at the top of this script)
gg_12_race <- patch_boxes_race(pcts_box, sources = c("Census 2010", "DAS-12.2", "DAS-4.5"))
gg_19_race <- patch_boxes_race(pcts_box, sources = c("Census 2010", "DAS-19.61"))


# Save plots
# Resolve the output paths with here(), like every input path in this script,
# so the figures land in <project root>/figs regardless of the working
# directory the script is run from.
ggsave(here("figs/race_boxplots.pdf"), gg_12_race, width = 6.5, height = 8.25)
ggsave(here("figs/race_boxplots_DAS19.pdf"), gg_19_race, width = 6.5, height = 8.25)
